#include "rpc_codec.hpp"

#include <assert.h>
#include <boost/crc.hpp>

#include "../base/logger.h"
#include "../net/tcp_connection.hpp"
#include "google-inl.h"
#include "rpc.pb.h"

namespace
{
    // Runs protobuf's GOOGLE_PROTOBUF_VERIFY_VERSION check once, during dynamic
    // initialization of this translation unit, so a mismatch between the protobuf
    // headers and the linked protobuf library is detected before any codec is used.
    int verifyProtobufVersion()
    {
        GOOGLE_PROTOBUF_VERIFY_VERSION;
        return 0;
    }

    // Never read; exists only to force verifyProtobufVersion() to run.
    int protobufVersionChecked = verifyProtobufVersion();

    // Names handed out by RpcCodec::errorCodeToString(). They have static storage
    // duration, so the references that function returns never dangle.
    const std::string kNoErrorStr("NoError");
    const std::string kInvalidLengthStr("InvalidLength");
    const std::string kCheckSumErrorStr("CheckSumError");
    const std::string kInvalidNameLenStr("InvalidNameLen");
    const std::string kUnknownMessageTypeStr("UnknownMessageType");
    const std::string kParseErrorStr("ParseError");
    const std::string kUnknownErrorStr("UnknownError");
}

namespace easyasio
{
namespace net
{

void RpcCodec::send(const TcpConnectionPtr& conn, const RpcMessage& message)
{
    Buffer buf;
    buf.append("RPC0", 4);

    // code copied from MessageLite::SerializeToArray() and MessageLite::SerializePartialToArray().
    GOOGLE_DCHECK(message.IsInitialized()) << InitializationErrorMessage("serialize", message);

    int byte_size = message.ByteSize();
    buf.ensureWritableBytes(byte_size + kHeaderLen);

    uint8_t* start = reinterpret_cast<uint8_t*>(buf.beginWrite());
    uint8_t* end = message.SerializeWithCachedSizesToArray(start);
    if (end - start != byte_size)
    {
        ByteSizeConsistencyError(byte_size, message.ByteSize(), static_cast<int>(end - start));
    }
    buf.hasWritten(byte_size);


    mnb::crc_basic<32>  crc_ccitt1(0x1021, 0xFFFF, 0, false, false);
    crc_ccitt1.process_bytes(buf.peek(), buf.readableBytes());
    buf.appendInt32(crc_ccitt1.checksum());

    assert(buf.readableBytes() == kHeaderLen + byte_size + kHeaderLen);
    int32_t len = asio::detail::socket_ops::host_to_network_long(static_cast<int32_t>(buf.readableBytes()));
    buf.prepend(&len, sizeof len);

    conn->send(buf.retrieveAllAsString());
}

void RpcCodec::onMessage(const TcpConnectionPtr& conn, Buffer* buf)
{
    while (buf->readableBytes() >= kMinMessageLen + kHeaderLen)
    {
        const int32_t len = buf->peekInt32();
        if (len > kMaxMessageLen || len < kMinMessageLen)
        {
            errorCallback_(conn, buf, kInvalidLength);
            break;
        }
        else if (buf->readableBytes() >= static_cast<size_t>(len + kHeaderLen))
        {
            RpcMessage message;
            // FIXME: can we move deserialization & callback to other thread?
            ErrorCode errorCode = parse(buf->peek() + kHeaderLen, len, &message);
            if (errorCode == kNoError)
            {
                // FIXME: try { } catch (...) { }
                messageCallback_(conn, message);
                buf->retrieve(kHeaderLen + len);
            }
            else
            {
                errorCallback_(conn, buf, errorCode);
                break;
            }
        }
        else
        {
            break;
        }
    }
}

// Returns a human-readable name for `errorCode`; any value without a dedicated
// name maps to "UnknownError". The reference points at a namespace-scope
// constant, so it remains valid for the lifetime of the program.
const std::string& RpcCodec::errorCodeToString(ErrorCode errorCode)
{
    const std::string* name = &kUnknownErrorStr;
    switch (errorCode)
    {
    case kNoError:            name = &kNoErrorStr;            break;
    case kInvalidLength:      name = &kInvalidLengthStr;      break;
    case kCheckSumError:      name = &kCheckSumErrorStr;      break;
    case kInvalidNameLen:     name = &kInvalidNameLenStr;     break;
    case kUnknownMessageType: name = &kUnknownMessageTypeStr; break;
    case kParseError:         name = &kParseErrorStr;         break;
    default:                                                  break;
    }
    return *name;
}

// Error handler with the same signature as errorCallback_ (invoked from onMessage()).
// It logs the error name and, if the connection is still connected, shuts it down.
// `buf` is accepted only to match the callback signature and is not used.
void RpcCodec::defaultErrorCallback(const TcpConnectionPtr& conn, Buffer* buf, ErrorCode errorCode)
{
    (void)buf;
    // The tag previously read "ProtobufCodec::..." (copy-paste), which pointed
    // log readers at the wrong class.
    LOG_INFO("RpcCodec::defaultErrorCallback - {}", errorCodeToString(errorCode));
    if (conn && conn->connected())
    {
        conn->shutdown();
    }
}

// Reads the 4 bytes at `buf` as a network-byte-order 32-bit integer and returns
// it in host byte order. The bytes are copied with memcpy, so `buf` does not
// have to be suitably aligned for int32_t.
int32_t RpcCodec::asInt32(const char* buf)
{
    int32_t networkOrder = 0;
    ::memcpy(&networkOrder, buf, sizeof networkOrder);
    const int32_t hostOrder = asio::detail::socket_ops::network_to_host_long(networkOrder);
    return hostOrder;
}

// Validates and decodes one frame payload (the `len` bytes after the length prefix).
//
// Expected layout of buf[0, len): ["RPC0" tag][serialized RpcMessage][int32 checksum].
// The checksum is stored in network byte order and covers the tag plus the
// serialized message.
// Checks run in this order: minimum length, checksum, tag, protobuf parse.
//
// Returns kNoError and fills *message on success. Otherwise returns
// kInvalidLength, kCheckSumError, kUnknownMessageType or kParseError.
RpcCodec::ErrorCode RpcCodec::parse(const char* buf, int len, RpcMessage* message)
{
    // The payload must at least hold the tag and the checksum. Without this guard,
    // `buf + len - kHeaderLen` would point before `buf` and the CRC length would be
    // negative.
    if (len < 2 * kHeaderLen)
    {
        return kInvalidLength;
    }

    // Compare both sides as unsigned 32-bit values. The wire field is read as a
    // signed int32_t, while crc_basic<32>::checksum() yields an unsigned type.
    const uint32_t expectedCheckSum = static_cast<uint32_t>(asInt32(buf + len - kHeaderLen));

    mnb::crc_basic<32> crc_ccitt1(0x1021, 0xFFFF, 0, false, false);
    crc_ccitt1.process_bytes(buf, len - kHeaderLen);
    if (static_cast<uint32_t>(crc_ccitt1.checksum()) != expectedCheckSum)
    {
        return kCheckSumError;
    }

    if (memcmp(buf, "RPC0", kHeaderLen) != 0)
    {
        return kUnknownMessageType;
    }

    // Parse the serialized message that sits between the tag and the checksum.
    const char* data = buf + kHeaderLen;
    const int32_t dataLen = len - 2 * kHeaderLen;
    return message->ParseFromArray(data, dataLen) ? kNoError : kParseError;
}

}
}

